using DifferentialEquations, Random, LinearAlgebra, NPZ, PyPlot

# Number of sites and the diagonal-strength factor handed to `mainMF` below.
N = 8
JdiagFactor = 100

# Make an SK matrix, with 0.5 weight for K like in the parallel tempering.
# Both parts are symmetrized Gaussian matrices scaled by 1/(2*sqrt(N)); the real
# part of `JK_` carries J and the imaginary part carries K.
KWeight = 0.5
Jrand = randn(N,N)
Jrand = (Jrand + Jrand')/(2*sqrt(N))
Krand = randn(N,N)
Krand = (Krand + Krand')/(2*sqrt(N))
# Apply the K weight here: it was previously defined but never used, so K entered
# with unit weight despite the comment above.
JK_ = Jrand .+ im*KWeight*Krand

# The eight 3x3 Gell-Mann matrices, one literal per line (rows separated by `;`).
# λ1, λ3, λ4, λ6 are integer-valued, λ2, λ5, λ7 are complex-integer-valued and
# λ8 is floating point because of the 1/sqrt(3) prefactor.
λ1 = [0 1 0; 1 0 0; 0 0 0]
λ2 = [0 -im 0; im 0 0; 0 0 0]
λ3 = [1 0 0; 0 -1 0; 0 0 0]
λ4 = [0 0 1; 0 0 0; 1 0 0]
λ5 = [0 0 -im; 0 0 0; im 0 0]
λ6 = [0 0 0; 0 0 1; 0 1 0]
λ7 = [0 0 0; 0 0 -im; 0 im 0]
λ8 = [1 0 0; 0 1 0; 0 0 -2] ./ sqrt(3)

# Collected so that λs[α] is the α-th matrix.
λs = [λ1, λ2, λ3, λ4, λ5, λ6, λ7, λ8];

"""
    mainMF(JdiagFactor, Jload)

Integrate the noisy mean-field equations for `N = size(Jload, 1)` sites, each
described by the eight expectation values of the Gell-Mann matrices in the global
`λs`. Every time step consists of a deterministic ODE step of length `dt`, a
stochastic kick built from pre-drawn Gaussian increments, and a per-site
renormalization of the 8-component vector.

# Arguments
- `JdiagFactor`: the diagonal term added to the couplings is `JdiagFactor` times
  the mean absolute value of the real part of the normalized coupling matrix.
- `Jload`: square coupling matrix; its real part is used to build `Jc`/`Js` and its
  imaginary part is used as `K`. The argument is not modified.

# Returns
`(observables, Jc, Js, K, g)` where `observables = (tvals, xt)`; `tvals[k]` is the
integrator time of the k-th record, `xt[α, i, k]` is component `α` of site `i` at
that record, `Jc`, `Js`, `K` are the coupling matrices after division by the
largest eigenvalue of the combined `2N×2N` matrix, and `g` is the pump schedule
as a function of time.
"""
function mainMF(JdiagFactor,Jload)

    # Extract system size from J matrix
    N = size(Jload, 1)

    # Parameters
    M = Int(2.3e5)
    ωz = 2π * 4e-3 # MHz
    Δc = 2π * 60 # MHz
    κ = 2π * 137e-3 # MHz
    holdT = 0e3 # μs

    # Integration parameters
    solver = TsitPap8()
    dt = 10 # us
    tspan = (0.0, 2000.0 + holdT) # us, 2000 for experimental setting
    nSteps = ceil(Int, (tspan[2] - tspan[1]) / dt)
    abstol = 1e-9 #1e-15 # Default is 1e-6
    reltol = 1e-9 # Default is 1e-3
    maxiters = 1_000_000

    # Observables
    nObs = round(Int, tspan[2] / 10)
    stepsPerMeasure = ceil(Int, ((tspan[2] - tspan[1]) / nObs) / dt)
    tvals = zeros(nObs)
    xt = zeros(8, N, nObs)

    # Normalize J by the HS norm of its real part. This is done on a fresh array so
    # the caller's matrix is left untouched, and the local is not called `norm` so
    # LinearAlgebra.norm is not shadowed inside this function.
    hsNorm = sqrt(sum(eigen(real.(Jload)).values .^ 2))
    Jn = Jload ./ hsNorm

    # Compute average strength of an element
    avgVal = sum(abs.(real.(Jn))) / N^2

    # Make Jc and Js
    Jloc = JdiagFactor * avgVal
    Jc = Jloc*I + real.(Jn) # 8(2.6) for 7/7 data, 12(4) for 7/14 data
    Js = Jloc*I - real.(Jn)
    K = imag.(Jn)
    if minimum(eigen(Jc).values)<0 || minimum(eigen(Js).values)<0
        println("Warning! Negative eigenvalues detected. Use a bigger diagonal component.")
    end

    # Normalize the J matrices by the largest eigenvalue of the block matrix
    # [Jc K; K Js].
    JBig = zeros(2N, 2N)
    JBig[1:N, 1:N] .= Jc
    JBig[N+1:2N, N+1:2N] .= Js
    JBig[N+1:2N, 1:N] .= K
    JBig[1:N, N+1:2N] .= K
    λmax = maximum(real.(eigen(JBig).values))
    Jc /= λmax
    Js /= λmax
    K  /= λmax

    # Save the eigenvectors and eigenvalues (used below to colour the noise)
    eigJc = eigen(Jc)
    evecsC = eigJc.vectors
    evalsC = eigJc.values
    eigJs = eigen(Js)
    evecsS = eigJs.vectors
    evalsS = eigJs.values

    # Pump schedule: linear ramp up to t = 1.5e3, an optional hold of length
    # `holdT`, then a constant higher value.
    function g(t)
        if t < 1.5e3
            return t/1e3 * ωz/2 *1.25
        elseif t < 1.5e3 + holdT
            return ωz/2*1.25 * 1.5
        else
            return ωz/2*1.25 * 2.5
        end
    end

    # Index table: tb[α,i] = 8(i-1)+α is the position of component α of site i in
    # the flat state vector.
    tb = reshape(collect(1:8N), 8, N)

    # Initial conditions
    ψ0 = zeros(8*N)
    for i=1:N
        # Normal state with noise
        # NOTE(review): a new random state is drawn for every component α, so the
        # eight components of one site do not come from a single state vector.
        # Confirm this is intended rather than drawing `state` once per site.
        for α=1:8
            state = [1; rand()*0.01*exp(im*2π*rand()); rand()*0.01*exp(im*2π*rand())]
            stateNorm = real(state'*state)
            ψ0[tb[α,i]] = real( state'*λs[α]*state )/stateNorm
        end
    end

    # Useful parameter combinations
    ωz2 = 2*ωz
    sqrt3 = sqrt(3)
    # Scratch buffers for the mean fields, reused on every derivative call
    Tc = zeros(N)
    Ts = zeros(N)

    # Derivative
    function deriv!(du,u,p,t)

        gt = g(t)

        for i=1:N

            # Compute T matrices
            Tc[i]=0
            Ts[i]=0
            for j=1:N
                Tc[i] += Jc[i,j]*u[tb[1,j]] + K[i,j]*u[tb[4,j]]
                Ts[i] += Js[i,j]*u[tb[4,j]] + K[i,j]*u[tb[1,j]]
            end

            # Derivative terms
            du[tb[1,i]] = ωz2*u[tb[2,i]] - 2*gt*( Ts[i]*u[tb[7,i]] )

            du[tb[2,i]] = -ωz2*u[tb[1,i]] + 2*gt*( 2*Tc[i]*u[tb[3,i]] - Ts[i]*u[tb[6,i]] )

            du[tb[3,i]] = -2*gt*( 2*Tc[i]*u[tb[2,i]] + Ts[i]*u[tb[5,i]] )

            du[tb[4,i]] = ωz2*u[tb[5,i]] + 2*gt*( Tc[i]*u[tb[7,i]] )

            du[tb[5,i]] = -ωz2*u[tb[4,i]] - 2*gt*( Tc[i]*u[tb[6,i]] - Ts[i]*(u[tb[3,i]] + sqrt3*u[tb[8,i]]) )

            du[tb[6,i]] = 2*gt*( Tc[i]*u[tb[5,i]] + Ts[i]*u[tb[2,i]] )

            du[tb[7,i]] = -2*gt*( Tc[i]*u[tb[4,i]] - Ts[i]*u[tb[1,i]] )

            du[tb[8,i]] = -sqrt3*2*gt*( Ts[i]*u[tb[5,i]] )

        end

    end

    # Initialize ODE functions
    fun = ODEFunction(deriv!) #; jac=jac!, jac_prototype=jac0)
    prob = ODEProblem(fun,ψ0,tspan)
    integrator = init(prob,solver,dt=dt,save_on=false,abstol=abstol,reltol=reltol,save_everystep=false,maxiters=maxiters)

    # Gaussian increments for the cosine and sine quadratures. Each entry has
    # standard deviation sqrt(dt/2), i.e. variance dt/2.
    dWc = randn(N,nSteps) .* sqrt(dt/2)
    dWs = randn(N,nSteps) .* sqrt(dt/2)

    # Noise steps: column t is Σ_k sqrt(2κ/Δc * eval_k) * dW[k,t] * evec_k,
    # written as a matrix product.
    dSc = evecsC * (sqrt.(2*κ/Δc .* evalsC) .* dWc)
    dSs = evecsS * (sqrt.(2*κ/Δc .* evalsS) .* dWs)

    # Do it
    idxMeas = 1
    kick = similar(ψ0) # every entry is overwritten before use in each step
    for step=1:nSteps

        # Deterministic step
        step!(integrator,dt,true)
        uc = integrator.u

        # Noisy Hamiltonian
        sqrtg = sqrt(g(integrator.t))
        for i=1:N
            kick[tb[1,i]] = uc[tb[7,i]]*dSs[i,step]

            kick[tb[2,i]] = -2*uc[tb[3,i]]*dSc[i,step] + uc[tb[6,i]]*dSs[i,step]

            kick[tb[3,i]] = 2*uc[tb[2,i]]*dSc[i,step] + uc[tb[5,i]]*dSs[i,step]

            kick[tb[4,i]] = -uc[tb[7,i]]*dSc[i,step]

            kick[tb[5,i]] = uc[tb[6,i]]*dSc[i,step] - (uc[tb[3,i]] + sqrt3*uc[tb[8,i]] )*dSs[i,step]

            kick[tb[6,i]] = -uc[tb[5,i]]*dSc[i,step] - uc[tb[2,i]]*dSs[i,step]

            kick[tb[7,i]] = uc[tb[4,i]]*dSc[i,step] - uc[tb[1,i]]*dSs[i,step]

            kick[tb[8,i]] = sqrt3*uc[tb[5,i]]*dSs[i,step]
        end
        uc .+= sqrtg .* kick ./ sqrt(M)

        # Renormalize each site so that (3/4) * Σ_α u_α^2 = 1, working on a view
        # instead of a copied slice.
        for i=1:N
            ui = view(uc, tb[:,i])
            ui ./= sqrt(sum(abs2, ui)*3/4)
        end

        # The state was changed behind the solver's back; flag it so the cached
        # derivative at the current point is re-evaluated before the next step.
        u_modified!(integrator, true)

        # Measure (guarded so a different dt/tspan choice cannot write past nObs)
        if mod(step-1,stepsPerMeasure)==0 && idxMeas <= nObs
            tvals[idxMeas]  = integrator.t

            for i=1:N
                for α=1:8
                    xt[α,i,idxMeas] = uc[tb[α,i]]
                end
            end

            idxMeas += 1
        end

    end

    observables = (tvals,xt)

    return observables,Jc,Js,K,g

end

#####################################################

# Run the simulation and unpack what it returns
observables, Jc, Js, K, g = mainMF(JdiagFactor, JK_)
tvals, xt = observables;
nObs = size(xt, 3)

#######################################################

# Analysis

# J matrices: half-sum and half-difference of the cosine and sine couplings
Jdiag = (Jc + Js) ./ 2
Jnon = Jc/2 - Js/2

# Vector angle of every spin at every record, from components 4 and 1,
# wrapped into [0, 2π)
θs = mod.(atan.(xt[4, :, :], xt[1, :, :]), 2π)

# Interaction energy, accumulated over all ordered pairs (i, j) per record
Es = zeros(nObs)
for t in 1:nObs, i in 1:N, j in 1:N
    φ = θs[i, t] + θs[j, t]
    Es[t] += -Jnon[i, j]*cos(φ) - K[i, j]*sin(φ)
end

# Integrated vector components during readout: sum components 4 and 1 over the
# records within the last 500 time units, then take the angle in units of π.
cmplxθs = exp.(im*θs)
readoutMask = tvals .> tvals[end] - 500
states = [mod(atan(sum(xt[4, i, readoutMask]), sum(xt[1, i, readoutMask])), 2π)/π for i in 1:N]
println("Readout spin config:")
println(states)

# Total spin: sum of the squares of the eight components, per spin and record
totalSpin = zeros(N, nObs)
for α in 1:8
    totalSpin .+= xt[α, :, :] .^ 2
end

######################################################

# Time axis shared by all figures (record times scaled by 1e-3, labelled in ms)
tms = tvals * 1e-3

# Vector components for a single spin (site 1)
fig = figure()
for α in 1:8
    # Component 3 is not drawn on its own; it appears only in the combination
    # plotted when α == 8.
    α == 3 && continue

    if α == 1
        plot(tms, xt[1,1,:], label="Cosine DW", color="tab:blue", linewidth=2)
    elseif α == 4
        plot(tms, xt[4,1,:], label="Sine DW", color="tab:red")
    elseif α == 8
        plot(tms, (xt[3,1,:] .+ sqrt(3)*xt[8,1,:])/2, label="Normal state", color="tab:green")
    else
        plot(tms, xt[α,1,:], color="black", alpha=0.2)
    end
end
# Pump schedule rescaled so that its final value is 2.5
plot(tms, g.(tvals)/g(tvals[end]) * 2.5, linestyle="--", color="black", label="Pump")
xlabel("Time [ms]", fontsize=14)
ylabel(L"\langle \Lambda^\alpha \rangle", fontsize=14)
# ylim(-0.51,0.51)
legend(loc="upper left")
PyPlot.display_figs()

# Selected spin components for a single spin
fig = figure()
plot(tms, (xt[3,1,:] .+ sqrt(3)*xt[8,1,:]), label="no DW component")
plot(tms, xt[1,1,:], label="cosine DW component")
plot(tms, xt[4,1,:], label="sine DW component")
xlabel("Time [ms]", fontsize=14)
# ylim(-1.05,2.05)
# axvline(1.2,label="Estimate of phase transition",alpha=0.25,color="black")
legend(loc=(1.01,0))
PyPlot.display_figs()

# θs: each angle is drawn three times (shifted by -2, 0, +2 in units of π) so that
# points near the edges of the [0, 2] window show up on both sides.
fig = figure()
colors = ["tab:blue","tab:orange","tab:green","tab:red","tab:purple","tab:brown","tab:pink","tab:gray"]
for i in 1:N
    spinColor = colors[1+mod(i,8)]
    θπ = θs[i,:] ./ π
    plot(tms, θπ, ".", color=spinColor, label=string("Spin ", i))
    plot(tms, θπ .- 2, ".", color=spinColor)
    plot(tms, θπ .+ 2, ".", color=spinColor)
end
ylim(-0.1,2.1)
xlabel("Time [ms]", fontsize=14)
ylabel("Vector angle [π rad]", fontsize=14)
# legend(loc=(1.01,0))
PyPlot.display_figs()

# Interaction energy
fig = figure()
plot(tms, Es)
xlabel("Time [ms]", fontsize=14)
ylabel("Interaction energy", fontsize=14)
# ylim(-1.1,0.1)
PyPlot.display_figs()


